package com.test.controller;

import com.baomidou.mybatisplus.core.toolkit.IdWorker;
import com.test.utils.AssertUtil;
import com.test.utils.LogWorker;
import lombok.AllArgsConstructor;
import lombok.Getter;
import org.slf4j.event.Level;
import org.springframework.lang.NonNull;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.annotation.Isolation;
import org.springframework.transaction.support.DefaultTransactionDefinition;

import java.util.concurrent.*;

/**
 * @author kang
 * @description 多线程事务
 * 思路：类似2PC（二阶段提交）：
 * 一阶段：各个子任务执行业务逻辑（异步），告知执行结果（成功/失败）；
 * 二阶段：如果一阶段中所有子任务都执行成功，则提交所有子任务的本地事务，否则回滚所有事务（存在任意子任务失败时）；
 * @date 2021/6/16
 */
public class AsyncTransactionManager {

    private static final LogWorker LOG_WORKER = LogWorker.of(AsyncTransactionManager.class, "ATC");

    /**
     * 整体任务执行状态
     */
    private volatile int state;

    /**
     * 子任务数量
     */
    private int transactionCount = 0;

    /**
     * 开启线程优先级。按照任务提交顺序设置优先级
     */
    private Boolean enablePriority;

    /**
     * 事务隔离级别
     */
    private Isolation isolation;

    /**
     * 支持异步非阻塞、背压的响应式消息发布工具，用于向订阅方广播发布消息（子任务调度）
     */
    private SubmissionPublisher<Payload> publisher;

    /**
     * Spring事务管理器
     */
    private PlatformTransactionManager transactionManager;

    /**
     * 用于阻塞所有子任务直到全部执行结束（成功或者失败）
     */
    private CyclicBarrier cyclicBarrier;

    /**
     * 用于阻塞主线程直到所有子任务全部执行结束
     */
    private CountDownLatch countDownLatch;

    /**
     * 整体任务过期时间（单位：秒）
     */
    private long taskTimeOut;

    /**
     * 请求唯一标识，方便日志定位
     */
    private Long rid;

    /**
     * 私有化构造器，避免外部直接实例化
     */
    private AsyncTransactionManager() {

    }

    public static Builder builder() {
        return new Builder();
    }


    /**
     * 构建实例信息
     */
    public static class Builder {

        private PlatformTransactionManager transactionManager;

        /**
         * 使用外部线程池执行子任务
         */
        private Executor executor;

        private Boolean enablePriority = false;

        private long taskTimeOut = 15L;

        private Isolation isolation = Isolation.DEFAULT;

        /**
         * SubmissionPublisher的消息缓存区大小，需要设置成一个2的n次幂。（目前场景只会发布一个消息，缓冲区大小设置为1）
         */
        private final int maxBufferCapacity = 1;

        public Builder executor(@NonNull Executor executor) {
            this.executor = executor;
            return this;
        }

        public Builder transactionManager(@NonNull PlatformTransactionManager transactionManager) {
            this.transactionManager = transactionManager;
            return this;
        }

        public Builder isolation(@NonNull Isolation isolation) {
            this.isolation = isolation;
            return this;
        }

        public Builder taskTimeOut(@NonNull long taskTimeOut) {
            this.taskTimeOut = taskTimeOut;
            return this;
        }

        public Builder enablePriority(@NonNull Boolean enable) {
            this.enablePriority = enable;
            return this;
        }

        public AsyncTransactionManager build() {
            AssertUtil.isNull(this.transactionManager, "transactionManager cannot be null!");
            var manager = new AsyncTransactionManager();
            manager.enablePriority = this.enablePriority;
            manager.transactionManager = this.transactionManager;
            manager.isolation = this.isolation;
            manager.state = State.INIT;
            manager.rid = IdWorker.getId();
            manager.taskTimeOut = this.taskTimeOut;
            if (this.executor == null) {
                manager.publisher = new SubmissionPublisher();
            } else {
                manager.publisher = new SubmissionPublisher(this.executor, this.maxBufferCapacity);
            }
            return manager;
        }

    }

    /**
     * 提交子任务，并监听事务消息进行处理
     */
    public AsyncTransactionManager run(@NonNull final Task task) {
        return this.run(task, taskTimeOut);
    }

    /**
     * 提交子任务，并监听事务消息进行处理
     * @param task      子任务
     * @param timeout   执行超时时间
     */
    public AsyncTransactionManager run(@NonNull final Task task, @NonNull long timeout) {
        final var taskPriority = Math.min(++transactionCount, Thread.MAX_PRIORITY);
        publisher.subscribe(new Flow.Subscriber<>() {

            private TransactionStatus transactionStatus;

            private Flow.Subscription subscription;

            /**
             * 订阅回调
             */
            @Override
            public void onSubscribe(final Flow.Subscription subscription) {
                LOG_WORKER.log(rid, Level.DEBUG, "received one subscribe");
                if (enablePriority) {
                    Thread.currentThread().setPriority(taskPriority);
                }
                /* 向上游发布者订阅一条消息 */
                (this.subscription = subscription).request(1);
            }

            /**
             * @description 监听子任务执行消息
             * @param payload 消息参数
             */
            @Override
            public void onNext(Payload payload) {
                /* 可能该任务线程刚抢到cpu执行权时已经存在其他子任务失败了，这里做个快速失败。中断当前线程也可以唤醒其他等待中的线程，一起提前回滚而无需等待，尽快释放事务。 */
                if (state != State.RUNNING) {
                    Thread.currentThread().interrupt();
                }
                try {
                    /* 开启子线程本地事务 */
                    var definition = new DefaultTransactionDefinition(TransactionDefinition.PROPAGATION_REQUIRES_NEW);
                    definition.setIsolationLevel(isolation.value());
                    this.transactionStatus = transactionManager.getTransaction(definition);
                    task.accept();
                    /* 子任务执行完毕，到达屏障等待其他子任务 */
                    cyclicBarrier.await(timeout, TimeUnit.SECONDS);
                } catch (Exception e) {
                    LOG_WORKER.log(rid, Level.ERROR, "current task execution failed! ", e);
                    stateForward(Action.ROLLBACK);
                    /* 任务执行失败，立即中断并唤醒其他等待中的线程进行回滚 */
                    Thread.currentThread().interrupt();
                } finally {
                    try {
                        /* 提交或回滚事务 */
                        this.processTransaction();
                        /* 停止消息监听 */
                        subscription.cancel();
                    } finally {
                        countDownLatch.countDown();
                    }
                }
            }

            @Override
            public void onError(Throwable throwable) {
                LOG_WORKER.log(rid, Level.ERROR, "something error happened! ", throwable);
            }

            @Override
            public void onComplete() {
                LOG_WORKER.log(rid, Level.DEBUG, "current task has been completed.");
            }

            /**
             * 对子任务事务进行提交或回滚
             */
            private void processTransaction() {
                LOG_WORKER.log(rid, Level.DEBUG, "ready to finish the transaction, task status: {}", state);
                switch (state) {
                    /* 只有所有子线程全部执行成功，才会进行事务提交 */
                    case State.SUCCESS:
                        transactionManager.commit(transactionStatus);
                        break;
                    /* 任意一个子线程执行失败，回滚所有事务 */
                    case State.FAIL:
                        if (!transactionStatus.isCompleted()) {
                            transactionManager.rollback(transactionStatus);
                        }
                        break;
                    default:
                        LOG_WORKER.log(rid, Level.ERROR, "task executed[{}] failed! transaction has been completed.", state);
                }
            }

        });
        return this;
    }

    /**
     * 开始执行所有子任务
     *
     * @return State 最终执行状态
     */
    public int execute() {
        AssertUtil.isFalse(this.transactionCount > 0, "请添加至少一个子任务");
        stateForward(Action.RUN);
        /* 设置事务屏障，所有子任务全部执行成功才会统一提交各个事务 */
        this.cyclicBarrier = new CyclicBarrier(this.transactionCount, () -> {
            LOG_WORKER.log(rid, Level.DEBUG, "all tasks[{}] has been executed successful, prepare to commit all transactions.", transactionCount);
            /* 子任务全部执行成功，设置任务执行状态（注意：此时各个子线程仍在await，执行该回调后才会去执行await后面的逻辑） */
            stateForward(Action.COMMIT);
        });
        /* 通过countDownLatch阻塞主线程，直到所有子任务执行结束 */
        this.countDownLatch = new CountDownLatch(this.transactionCount);
        /* 发布任务执行消息，各个子任务开始异步执行 */
        publisher.submit(Payload.of(null));
        try {
            LOG_WORKER.log(rid, Level.DEBUG, "main thread is waitting for all of the tasks to finish transactions.");
            this.countDownLatch.await(taskTimeOut, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        stateForward(Action.FINISH);
        LOG_WORKER.log(rid, Level.DEBUG, "task execution finished! result: {}", this.state);
        return this.state;
    }


    /**
     * 消息载体（暂时用不着，预留）
     */
    @Getter
    @AllArgsConstructor
    protected static class Payload {

        private TransactionStatus transactionStatus;

        public static Payload of(TransactionStatus transactionStatus) {
            return new Payload(transactionStatus);
        }

    }

    /**
     * 任务事件。用于实现简易的事务状态机
     */
    @Getter
    @AllArgsConstructor
    protected enum Action {

        /**
         * 开启事务
         */
        RUN(State.INIT, State.RUNNING),

        /**
         * 提交事务
         */
        COMMIT(State.RUNNING, State.SUCCESS),

        /**
         * 回滚事务
         */
        ROLLBACK(State.RUNNING | State.FAIL, State.FAIL),

        /**
         * 结束
         */
        FINISH(State.SUCCESS | State.FAIL, State.DONE);

        /**
         * 当前状态集
         */
        private int start;

        /**
         * 目标状态
         */
        private int next;

    }

    /**
     * 事务状态
     */
    public class State {

        public static final int INIT    =   0b00000001;
        public static final int RUNNING =   INIT << 1;
        public static final int SUCCESS =   INIT << 2;
        public static final int FAIL    =   INIT << 3;
        public static final int DONE    =   INIT << 4;

    }

    /**
     * 通过事件推进状态进行扭转
     */
    private int stateForward(Action action) {
        AssertUtil.isTrue((action.getStart() & state) == 0, "事务状态异常，预期： {}， 当前： {}", action.getStart(), state);
        return state = action.getNext();
    }

    /**
     * 子任务
     */
    @FunctionalInterface
    public interface Task {

        void accept();

    }

}
